{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d658f909-e679-41e9-9c4e-e0241c719049",
   "metadata": {},
   "source": [
    "If you're not running in Saturn Cloud, you need to install these libraries:\n",
    "\n",
    "Make sure you use the latest versions\n",
    "\n",
    "```\n",
    "pip install -U transformers accelerate bitsandbytes\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6744a00f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['HF_HOME'] = '/run/cache/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "506fab2a-a50c-42bd-a106-c83a9d2828ea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2024-06-13 19:36:42--  https://raw.githubusercontent.com/alexeygrigorev/minsearch/main/minsearch.py\n",
      "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.108.133, ...\n",
      "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 3832 (3.7K) [text/plain]\n",
      "Saving to: ‘minsearch.py’\n",
      "\n",
      "minsearch.py        100%[===================>]   3.74K  --.-KB/s    in 0s      \n",
      "\n",
      "2024-06-13 19:36:42 (50.3 MB/s) - ‘minsearch.py’ saved [3832/3832]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!rm -f minsearch.py\n",
    "!wget https://raw.githubusercontent.com/alexeygrigorev/minsearch/main/minsearch.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3ac947de-effd-4b61-8792-a6d7a133f347",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<minsearch.Index at 0x7fc130288970>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import requests \n",
    "import minsearch\n",
    "\n",
    "docs_url = 'https://github.com/DataTalksClub/llm-zoomcamp/blob/main/01-intro/documents.json?raw=1'\n",
    "docs_response = requests.get(docs_url)\n",
    "documents_raw = docs_response.json()\n",
    "\n",
    "documents = []\n",
    "\n",
    "for course in documents_raw:\n",
    "    course_name = course['course']\n",
    "\n",
    "    for doc in course['documents']:\n",
    "        doc['course'] = course_name\n",
    "        documents.append(doc)\n",
    "\n",
    "index = minsearch.Index(\n",
    "    text_fields=[\"question\", \"text\", \"section\"],\n",
    "    keyword_fields=[\"course\"]\n",
    ")\n",
    "\n",
    "index.fit(documents)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "8f087272-b44d-4738-9ea2-175ec63a058b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def search(query):\n",
    "    boost = {'question': 3.0, 'section': 0.5}\n",
    "\n",
    "    results = index.search(\n",
    "        query=query,\n",
    "        filter_dict={'course': 'data-engineering-zoomcamp'},\n",
    "        boost_dict=boost,\n",
    "        num_results=3\n",
    "    )\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fe8bff3e-b672-42be-866b-f2d9bb217106",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rag(query):\n",
    "    search_results = search(query)\n",
    "    prompt = build_prompt(query, search_results)\n",
    "    answer = llm(prompt)\n",
    "    return answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "07d4639a",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ['HF_TOKEN'] = 'hf_blabla'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2b69a481",
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import login"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6650945d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
      "Token is valid (permission: read).\n",
      "Your token has been saved to /run/cache/token\n",
      "Login successful\n"
     ]
    }
   ],
   "source": [
    "login(token=os.environ['HF_TOKEN'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cfe5b58d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM\n",
    "from transformers import AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "311fd41d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "95ce1023e96346768985fad0ffedbaad",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ab9ccd09e5c546c18678fab88a4b4536",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "037d21f6a6b9419ca83f5ead1cd22f85",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4fc2ef41f8414b77a8f833b8e1e794ec",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5daa1086719b416cade75549effb2337",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "63b1c033ad86405eb7455989951ba2df",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0dc66a7b147246f28eacf34998eb0c4f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "28ea347a775e4b7d91e0923bfc89b697",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json:   0%|          | 0.00/967 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "17b14dc41f66482abd9efe217be69872",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "df30ec1fb8864ea39286b637f9a0cfcd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6613cd42eae24c86b3776f2062b19de7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    \"mistralai/Mistral-7B-v0.1\", device_map=\"auto\", load_in_4bit=True\n",
    ")\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"mistralai/Mistral-7B-v0.1\", padding_side=\"left\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b7f6fd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "generator = pipeline(\"text-generation\", model=model, tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "da66a082",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_prompt(query, search_results):\n",
    "    prompt_template = \"\"\"\n",
    "QUESTION: {question}\n",
    "\n",
    "CONTEXT:\n",
    "{context}\n",
    "\n",
    "ANSWER:\n",
    "\"\"\".strip()\n",
    "\n",
    "    context = \"\"\n",
    "    \n",
    "    for doc in search_results:\n",
    "        context = context + f\"{doc['question']}\\n{doc['text']}\\n\\n\"\n",
    "    \n",
    "    prompt = prompt_template.format(question=query, context=context).strip()\n",
    "    return prompt\n",
    "\n",
    "def llm(prompt):\n",
    "    response = generator(prompt, max_length=500, temperature=0.7, top_p=0.95, num_return_sequences=1)\n",
    "    response_final = response[0]['generated_text']\n",
    "    return response_final[len(prompt):].strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "41b774d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'Yes, you can still join the course.'"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rag(\"I just discovered the course. Can I still join it?\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "saturn (Python 3)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
